import jax
import jax.numpy as jnp
import json
import argparse
import os, sys

# Make the sibling checkouts two directories up importable. Each entry is
# prepended to sys.path, so the last one in the tuple ends up searched first
# ('sae-softmax', then 'lsh', then 'sae-jax').
_script_dir = os.path.dirname(__file__)
for _sibling in ('sae-jax', 'lsh', 'sae-softmax'):
    root = os.path.abspath(os.path.join(_script_dir, '..', '..', _sibling))
    sys.path.insert(0, root)

from lsh import GemmaLSH

from sae_save_load import (
    save_model, 
    load_model, 
    save_checkpoint, 
    load_checkpoint,
    save_metadata,
    load_jax_sae_to_pytorch,
    encode_sparse_torch
)


from functools import partial
import numpy as np
from typing import Union, List

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from dual import GemmaEmbeddingPredictor

from eval_misc import process_batch, get_sparse_representations_and_reconstructions, find_top_k_embeddings_cosine_similarity, SparseCodeMatcher, compute_total_variation_distance, compute_top_k_overlap, quarter_sentence
from datasets import load_dataset


def parse_args(argv=None):
    """Parse command-line options for the SAE-aware token-prediction evaluation.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace. A leading "~" is expanded in every path-valued
        option (``output_dir`` only when it was given; when it is omitted it
        stays None so the caller can derive a default from the model name).
    """
    def _str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (including
        # "False") into True, so parse the common spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Evaluate SAE-aware token prediction')
    parser.add_argument('--sae_model_path', type=str,
                        default='~/gemma-7b-sae/k5_whole_sae_final_model.pkl',
                        help='Path to the SAE model')
    parser.add_argument('--sae_code_path', type=str,
                        default='~/gemma-7b-sae/k5_whole_sae_final_z.npy',
                        help='Path to the SAE code')
    parser.add_argument('--mlp_model_path', type=str,
                        default="~/dual-map/dual_map_mlp_model_small_rescaled_experimental.pt",
                        help='Path to the MLP model')
    parser.add_argument('--top_k_next_token', type=int, default=10,
                        help='Number of top k tokens to consider')
    parser.add_argument('--tok_n_candidates', type=int, default=5000,
                        help='Number of token candidates')
    parser.add_argument('--model_name', type=str, default="google/gemma-7b",
                        help='Name of the base model to use')
    parser.add_argument('--dataset_name', type=str, default="bookcorpus",
                        help='Name of the dataset to use')
    parser.add_argument('--total_samples', type=int, default=1000,
                        help='Total number of samples to evaluate')
    parser.add_argument('--save_every', type=int, default=50,
                        help='Save results every N samples')
    parser.add_argument('--cache_dir', type=str, default="~/gemma_cache",
                        help='Directory to cache model files')
    # nargs='?' + const=True: a bare "--whitening" enables it, while the
    # previous "--whitening True/False" spelling keeps working.
    parser.add_argument('--whitening', type=_str2bool, nargs='?', const=True, default=False,
                        help='Whether to use whitening (accepts true/false; bare flag means true)')
    # The two LSH modes contradict each other; reject the combination up front.
    lsh_mode = parser.add_mutually_exclusive_group()
    lsh_mode.add_argument('--lsh_only', action='store_true',
                          help='Only run LSH-related evaluations')
    lsh_mode.add_argument('--no_lsh', action='store_true',
                          help='Skip LSH-related evaluations')
    parser.add_argument('--num_hash_tables', type=int, default=8,
                        help='Number of hash tables to use in LSH')
    parser.add_argument('--num_hash_functions', type=int, default=4,
                        help='Number of hash functions per table in LSH')
    parser.add_argument('--output_dir', type=str,
                        default=None,
                        help='Base directory to save output files. If not specified, will use ~/results/sae-softmax/{model_name}')
    parser.add_argument('--no_load_sae_weights', action='store_true',
                        help='Disable loading SAE weights when loading the model (weights are loaded by default)')

    args = parser.parse_args(argv)

    # np.load / torch.load / os.makedirs do not expand "~" themselves.
    for name in ("sae_model_path", "sae_code_path", "mlp_model_path", "cache_dir"):
        setattr(args, name, os.path.expanduser(getattr(args, name)))
    if args.output_dir is not None:
        args.output_dir = os.path.expanduser(args.output_dir)

    return args

# Parse command line arguments
args = parse_args()

# Set default output directory if not specified
if args.output_dir is None:
    model_name_short = args.model_name.split('/')[-1]  # Get just the model name without org
    args.output_dir = f"~/results/sae-softmax/{model_name_short}"

# Expand a leading "~" in every path option. os.makedirs / np.load do not do
# this themselves, so "~/results/..." would otherwise create a directory
# literally named "~" under the current working directory. (expanduser is
# idempotent, so already-expanded paths are unaffected.)
for _path_option in ("output_dir", "sae_model_path", "sae_code_path", "mlp_model_path", "cache_dir"):
    setattr(args, _path_option, os.path.expanduser(getattr(args, _path_option)))

# Create output directory if it doesn't exist
os.makedirs(args.output_dir, exist_ok=True)

sae_model = load_jax_sae_to_pytorch(args.sae_model_path, load_weights=not args.no_load_sae_weights)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gemma_tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=args.cache_dir).to(device)

# Token-id -> token-string table; ids absent from the tokenizer vocab keep "<unused>".
vocab_dict = gemma_tokenizer.get_vocab()
vocab_list = ["<unused>"] * (max(vocab_dict.values()) + 1)
for word, index in vocab_dict.items():
    vocab_list[index] = word

# SAE codes loaded from disk, and the model's output-embedding weight matrix.
z = np.load(args.sae_code_path)
output_embeddings = gemma_model.get_output_embeddings().weight.to(device)
input_dim = output_embeddings.shape[1]

class NextToken:
    """Next-token logit predictor that restricts the vocabulary via SAE codes.

    ``get_next_logits`` always starts from the exact logits
    ``last_token_embedding @ output_embeddings.T``. For the non-default
    methods it predicts a "next token" embedding, encodes it with the SAE,
    retrieves candidate token ids from the SAE codes ``z``, and sets every
    non-candidate logit to ``-inf``.
    """

    # Methods accepted by get_next_logits.
    _VALID_METHODS = ("default", "full", "full+sparselookup", "approxi", "approxi+sparselookup")

    def __init__(self, z, output_embeddings, mlp_model_path, sae_model, top_n_canidates=4000):
        """
        Args:
            z: SAE codes used for candidate retrieval, both through
                ``SparseCodeMatcher`` and through cosine-similarity lookup.
            output_embeddings: Output-embedding matrix G. ``shape[1]`` is
                taken as the embedding dimension and logits are computed as
                ``h @ G.T`` (so presumably (vocab, dim) — confirm).
            mlp_model_path: Passed to ``GemmaEmbeddingPredictor``.
            sae_model: Passed to ``encode_sparse_torch``.
            top_n_canidates: ``k`` for the cosine-similarity lookup (the
                misspelled name is kept for backward compatibility).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.z = z
        self.sparsecode = SparseCodeMatcher(z)
        # Detach: the output-embedding weight is an nn.Parameter; without this
        # the centring/SVD below is recorded by autograd and keeps extra
        # full-size copies alive.
        self.original_g = output_embeddings.detach().to(self.device)
        self.input_dim = self.original_g.shape[1]

        with torch.no_grad():
            self.mean = self.original_g.mean(axis=0)
            # Only the singular values and right singular vectors are needed;
            # discarding U right away frees a vocab-sized matrix.
            _, s, vt = torch.linalg.svd(self.original_g - self.mean, full_matrices=False)
            # NOTE(review): this scales by 1/sqrt(singular value), not by the
            # inverse covariance square root of textbook ZCA whitening. Kept
            # unchanged because the SAE was presumably trained on this exact
            # transform — confirm before "fixing".
            self.whitening_matrix = vt.T @ torch.diag(1.0 / torch.sqrt(s + 1e-6)) @ vt

        self.embedpredictor = GemmaEmbeddingPredictor(input_dim=self.input_dim, mlp_model_path=mlp_model_path)
        self.sae_model = sae_model
        self.top_n_canidates = top_n_canidates

    def get_next_logits(self,
                        last_token_embedding,
                        method="default", top_candidates=200000):
        """Compute (possibly vocabulary-restricted) next-token logits.

        Args:
            last_token_embedding: Hidden state to score against the output
                embeddings. The "full" methods apply softmax over dim 0, so a
                single 1-D vector is expected.
            method: One of
                - "default": exact logits only.
                - "full" / "full+sparselookup": the test embedding is the
                  softmax-weighted mean output embedding, centred and whitened.
                - "approxi" / "approxi+sparselookup": the test embedding comes
                  from ``GemmaEmbeddingPredictor``.
                With "+sparselookup", candidates come from
                ``SparseCodeMatcher.retrieve_similar_codes``; otherwise from a
                cosine-similarity top-``top_n_canidates`` search over ``z``.
            top_candidates: ``max_codes`` for the sparse lookup.

        Returns:
            For "default": the logits tensor.
            Otherwise: ``(masked_logits, n_candidates)``, where logits of
            non-retrieved tokens are ``-inf``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method not in self._VALID_METHODS:
            raise ValueError(f"unknown method {method!r}; expected one of {self._VALID_METHODS}")

        with torch.no_grad():
            last_token_embedding = last_token_embedding.to(self.device)
            next_token_logits = last_token_embedding @ self.original_g.T

            if method == "default":
                return next_token_logits

            base_method, _, lookup = method.partition("+")

            if base_method == "full":
                next_token_probs = torch.softmax(next_token_logits, dim=0)
                expected_unembedding = next_token_probs @ self.original_g
                test_embedding = (expected_unembedding - self.mean) @ self.whitening_matrix
            else:  # "approxi"
                test_embedding = self.embedpredictor.predict_next_token_embedding(last_token_embedding)

            test_z = encode_sparse_torch(self.sae_model, test_embedding)

            if lookup == "sparselookup":
                top_indices = self.sparsecode.retrieve_similar_codes(test_z, max_codes=top_candidates)
            else:
                top_indices = find_top_k_embeddings_cosine_similarity(self.z, test_z, k=self.top_n_canidates)

            # True = mask out; only the retrieved token ids keep their logits.
            mask = torch.ones_like(next_token_logits, dtype=torch.bool)
            mask[top_indices] = False
            next_token_logits = next_token_logits.masked_fill(mask, float("-inf"))

            return next_token_logits, len(top_indices)
                                

# Predictor built from the SAE codes and the model's output-embedding matrix.
nextokenpred = NextToken(z, output_embeddings, args.mlp_model_path, sae_model)

ds = load_dataset(args.dataset_name, split="train", trust_remote_code=True)

total_samples = args.total_samples
# NOTE(review): `save_every` is not referenced anywhere else in the visible
# code; only the final save at the end of the script writes results.
save_every = args.save_every

# SAE-path metrics: one TV distance per sample, one row of top-k overlaps
# (k = 1..20) per sample, and the candidate-set size per sample.
tv_distances = []
top_k_overlaps = np.zeros((total_samples, 20))
len_top_indices_cpu_list = []

# LSH metrics and indexes exist only when LSH evaluation is enabled; the
# whitened variants additionally require --whitening.
if not args.no_lsh:
    tv_distances_lsh, tv_distances_lsh_dual = [], []
    top_k_overlaps_lsh = np.zeros((total_samples, 20))
    top_k_overlaps_lsh_dual = np.zeros((total_samples, 20))

    lsh = GemmaLSH(gemma_model, gemma_tokenizer, num_hash_tables=args.num_hash_tables, num_hash_functions=args.num_hash_functions)

    if args.whitening:
        tv_distances_lsh_whitening, tv_distances_lsh_whitening_dual = [], []
        top_k_overlaps_lsh_whitening = np.zeros((total_samples, 20))
        top_k_overlaps_lsh_whitening_dual = np.zeros((total_samples, 20))

        print("Whitening is True")
        lsh_whitening = GemmaLSH(gemma_model, gemma_tokenizer, whitening=True, num_hash_tables=args.num_hash_tables, num_hash_functions=args.num_hash_functions)
from tqdm import tqdm

def _top_k_overlap_row(reference_logits, candidate_logits, max_k=20):
    """Return ``compute_top_k_overlap(reference, candidate, k)`` for k = 1..max_k.

    Entry ``k - 1`` of the returned float array holds the value for ``k``.
    """
    row = np.zeros(max_k)
    for k in range(1, max_k + 1):
        row[k - 1] = compute_top_k_overlap(reference_logits, candidate_logits, k)
    return row


# Evaluate at most `total_samples` prompts. Selecting those rows first avoids
# potentially building the entire "text" column (very large for corpora such
# as bookcorpus) just to read its first entries.
num_eval_samples = max(0, min(total_samples, len(ds)))
cnt = 0  # prompts evaluated so far == next row index into the overlap matrices
for sentence in tqdm(ds.select(range(num_eval_samples))["text"], total=num_eval_samples):
    prompt = quarter_sentence(sentence)
    inputs = gemma_tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = gemma_model(**inputs, output_hidden_states=True)

    # Final entry of hidden_states, at the last prompt position of batch item 0.
    last_token_embedding = outputs.hidden_states[-1][0, -1, :]

    # Reference: exact logits over the full vocabulary.
    default_logits = nextokenpred.get_next_logits(last_token_embedding, method="default")

    if not args.lsh_only:
        approxi_logits, len_top_indices_cpu = nextokenpred.get_next_logits(
            last_token_embedding, method="approxi+sparselookup", top_candidates=args.tok_n_candidates
        )
        len_top_indices_cpu_list.append(len_top_indices_cpu)
        tv_dist = compute_total_variation_distance(approxi_logits, default_logits)
        tv_distances.append(tv_dist.item())

        test_array = _top_k_overlap_row(default_logits, approxi_logits)
        top_k_overlaps[cnt, :] = test_array

    if not args.no_lsh:
        lsh_kwargs = dict(top_candidates=args.tok_n_candidates, return_original=False, return_logits=True)

        lsh_logits = lsh.predict_next_token(prompt, dual_transform=False, **lsh_kwargs)
        lsh_logits_dual = lsh.predict_next_token(prompt, dual_transform=True, **lsh_kwargs)

        tv_dist_lsh = compute_total_variation_distance(lsh_logits, default_logits)
        tv_distances_lsh.append(tv_dist_lsh.item())
        tv_dist_lsh_dual = compute_total_variation_distance(lsh_logits_dual, default_logits)
        tv_distances_lsh_dual.append(tv_dist_lsh_dual.item())

        test_array_lsh = _top_k_overlap_row(default_logits, lsh_logits)
        test_array_lsh_dual = _top_k_overlap_row(default_logits, lsh_logits_dual)
        top_k_overlaps_lsh[cnt, :] = test_array_lsh
        top_k_overlaps_lsh_dual[cnt, :] = test_array_lsh_dual

        if args.whitening:
            lsh_whitening_logits = lsh_whitening.predict_next_token(prompt, dual_transform=False, **lsh_kwargs)
            lsh_whitening_logits_dual = lsh_whitening.predict_next_token(prompt, dual_transform=True, **lsh_kwargs)

            tv_dist_lsh_whitening = compute_total_variation_distance(lsh_whitening_logits, default_logits)
            tv_distances_lsh_whitening.append(tv_dist_lsh_whitening.item())
            tv_dist_lsh_whitening_dual = compute_total_variation_distance(lsh_whitening_logits_dual, default_logits)
            tv_distances_lsh_whitening_dual.append(tv_dist_lsh_whitening_dual.item())

            test_array_lsh_whitening = _top_k_overlap_row(default_logits, lsh_whitening_logits)
            test_array_lsh_whitening_dual = _top_k_overlap_row(default_logits, lsh_whitening_logits_dual)
            top_k_overlaps_lsh_whitening[cnt, :] = test_array_lsh_whitening
            top_k_overlaps_lsh_whitening_dual[cnt, :] = test_array_lsh_whitening_dual

    # Per-sample progress output. The TV distances are printed only for the
    # variants that were actually computed (previously this raised NameError
    # whenever --whitening was off or --no_lsh was set).
    print(prompt)
    if not args.lsh_only:
        print(test_array)
    if not args.no_lsh:
        print(test_array_lsh_dual)
        if args.whitening:
            print(test_array_lsh_whitening_dual)
            print(tv_dist_lsh_whitening_dual, tv_dist_lsh_dual)
        else:
            print(tv_dist_lsh_dual)

    cnt += 1

# Final save. The overlap matrices were preallocated with `total_samples` rows,
# but only the first `cnt` rows were filled by the evaluation loop, so they are
# truncated before saving (this matters when the dataset is shorter than
# --total_samples; otherwise the shapes are unchanged).
if not args.lsh_only:
    # Add suffix to filenames if weights weren't loaded
    suffix = "_no_weights" if args.no_load_sae_weights else ""
    len_top_indices = np.array(len_top_indices_cpu_list)

    sae_results = {
        "tv_distances_sae": np.array(tv_distances),
        "top_k_overlaps_sae": np.asarray(top_k_overlaps)[:cnt],
        "len_top_indices_cpu": len_top_indices,
    }
    for stem, values in sae_results.items():
        np.save(os.path.join(args.output_dir, f"{stem}{suffix}.npy"), values)

    # Mean / std of the candidate-set size returned by the sparse lookup.
    print(len_top_indices.mean(), len_top_indices.std())

if not args.no_lsh:
    # LSH result files are tagged with the hash-table / hash-function counts.
    lsh_tag = f"t{args.num_hash_tables}_f{args.num_hash_functions}"

    lsh_results = {
        "tv_distances_lsh": np.array(tv_distances_lsh),
        "tv_distances_lsh_dual": np.array(tv_distances_lsh_dual),
        "top_k_overlaps_lsh": np.asarray(top_k_overlaps_lsh)[:cnt],
        "top_k_overlaps_lsh_dual": np.asarray(top_k_overlaps_lsh_dual)[:cnt],
    }
    if args.whitening:
        lsh_results.update({
            "tv_distances_lsh_whitening": np.array(tv_distances_lsh_whitening),
            "tv_distances_lsh_whitening_dual": np.array(tv_distances_lsh_whitening_dual),
            "top_k_overlaps_lsh_whitening": np.asarray(top_k_overlaps_lsh_whitening)[:cnt],
            "top_k_overlaps_lsh_whitening_dual": np.asarray(top_k_overlaps_lsh_whitening_dual)[:cnt],
        })

    for stem, values in lsh_results.items():
        np.save(os.path.join(args.output_dir, f"{stem}_{lsh_tag}.npy"), values)